Variational Inference with Implicit Approximate Inference Models (WIP Pt. 9)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
Using TensorFlow backend.
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 140244148646072 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)140244148249264 dense_1: Denseinput:output:(None, 2)(None, 10)140244148646072->140244148249264 140244148647304 dense_2: Denseinput:output:(None, 10)(None, 20)140244148249264->140244148647304 140244148649320 logit: Denseinput:output:(None, 20)(None, 1)140244148647304->140244148649320 140244149096912 activation_1: Activationinput:output:(None, 1)(None, 1)140244148649320->140244149096912
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.67780220508575439, 0.60000002384185791]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 140244148376016 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)140244147730808 dense_3: Denseinput:output:(None, 3)(None, 10)140244148376016->140244147730808 140244147730864 dense_4: Denseinput:output:(None, 10)(None, 20)140244147730808->140244147730864 140244148842448 dense_5: Denseinput:output:(None, 20)(None, 2)140244147730864->140244148842448
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(200, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(200, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
metrics
Out[37]:
[0.68443809032440184, 0.40500000000000003]
In [38]:
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
In [39]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

metrics_plots = {k:ax1.plot([], label=k)[0] 
                 for k in ['loss']} # discriminator.metrics_names}

ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [40]:
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):

    # Single training epoch
    
    for step in tnrange(steps_per_epoch, unit='step', leave=False):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    # Plot Metrics
        
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))

    for metric in metrics_plots:
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(), 
                                                  epoch_num))    
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(), 
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric, 
                                                metrics_dict[metric]))
    
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')

    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)

    return list(metrics_plots.values())
In [41]:
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate` 
# function that encapsulates the logic of single
# training epoch. Has benefit of producing 
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)

    anim_html5_video = anim.to_html5_video()
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

In [42]:
HTML(anim_html5_video)
Out[42]:
In [43]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [44]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [45]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [46]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

metrics_dict = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**metrics_dict), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [47]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [48]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[48]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [49]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[49]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [50]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [51]:
sess = K.get_session()
In [52]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[52]:
True
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Reweight likelihood term!

In [54]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(2.*log_likelihood-kl_estimate, axis=-1)

    return elbo
In [55]:
elbo = make_elbo(ratio_estimator)
In [56]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [57]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [58]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [59]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [60]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[60]:
<tf.Tensor 'concat:0' shape=(200, 3) dtype=float32>
In [61]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[61]:
-4.8959107
In [62]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/200 [===>..........................] - ETA: 0s
Out[62]:
4.8959110260009764

Adversarial Training

In [63]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

global_epoch = 0

loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')

ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [64]:
def train_animate(epoch_num, prog_bar, batch_size=200, 
                  steps_per_epoch=15):

    global global_epoch, loss_plot_inference, loss_plot_discrim
    
    # Single training epoch

    ## Ratio estimator training
        
    set_trainable(discriminator, True)

    for _ in tnrange(3*50, unit='step', desc='discriminator', 
                     leave=False):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

    metrics_dict_discrim = dict(zip(discriminator.metrics_names, 
                                    np.atleast_1d(metrics_discrim)))
    
    ## Inference model training
    
    set_trainable(ratio_estimator, False)

    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))

    for _ in tnrange(1, unit='step', desc='inference', leave=False):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)
        
    metrics_dict_inference = dict(zip(inference.metrics_names, 
                                      np.atleast_1d(metrics_inference)))

    global_epoch += 1
    
    # Plot Loss
 
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(), 
                                            metrics_dict_inference['loss']))

    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))

    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))

    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                  .format(metrics_dict_discrim['loss']))
    
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    
    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])

    return loss_plot_inference, loss_plot_discrim
In [65]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[65]:
In [66]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[66]:
In [67]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[67]:
In [68]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[68]:

Evaluating the model

In [69]:
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
In [70]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [71]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax1.set_xlabel('$w_1$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

ax2.contourf(w1, w2, np.sum(llhs, axis=2), 
             cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [72]:
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
In [73]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax1.scatter(*inference.predict(eps[::10]).T, 
            s=4.**2, alpha=.6, cmap='coolwarm_r')

ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

sns.kdeplot(*inference.predict(eps).T,
            cmap='magma', ax=ax2)

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 8)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
Using TensorFlow backend.
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 139996050319736 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)139996049881464 dense_1: Denseinput:output:(None, 2)(None, 10)139996050319736->139996049881464 139996049882416 dense_2: Denseinput:output:(None, 10)(None, 20)139996049881464->139996049882416 139996050321304 logit: Denseinput:output:(None, 20)(None, 1)139996049882416->139996050321304 139996050741624 activation_1: Activationinput:output:(None, 1)(None, 1)139996050321304->139996050741624
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.62235820293426514, 0.60000002384185791]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 139996049667352 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)139996049443528 dense_3: Denseinput:output:(None, 3)(None, 10)139996049667352->139996049443528 139996049443416 dense_4: Denseinput:output:(None, 10)(None, 20)139996049443528->139996049443416 139996049667520 dense_5: Denseinput:output:(None, 20)(None, 2)139996049443416->139996049667520
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(200, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(200, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
metrics
Out[37]:
[0.69308857917785649, 0.47749999999999998]
In [38]:
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
In [39]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

metrics_plots = {k:ax1.plot([], label=k)[0] 
                 for k in ['loss']} # discriminator.metrics_names}

ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [40]:
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):

    # Single training epoch
    
    for step in tnrange(steps_per_epoch, unit='step', leave=False):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    # Plot Metrics
        
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))

    for metric in metrics_plots:
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(), 
                                                  epoch_num))    
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(), 
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric, 
                                                metrics_dict[metric]))
    
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')

    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)

    return list(metrics_plots.values())
In [41]:
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate` 
# function that encapsulates the logic of single
# training epoch. Has benefit of producing 
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)

    anim_html5_video = anim.to_html5_video()
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

In [42]:
HTML(anim_html5_video)
Out[42]:
In [43]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [44]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [45]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [46]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

metrics_dict = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**metrics_dict), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [47]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [48]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[48]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [49]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[49]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [50]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [51]:
sess = K.get_session()
In [52]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[52]:
True
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [54]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [55]:
elbo = make_elbo(ratio_estimator)
In [56]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [57]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [58]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [59]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [60]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[60]:
<tf.Tensor 'concat:0' shape=(200, 3) dtype=float32>
In [61]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[61]:
-2.5127466
In [62]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/200 [===>..........................] - ETA: 0s
Out[62]:
2.5127464294433595

Adversarial Training

In [63]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

global_epoch = 0

loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')

ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')

ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [64]:
def train_animate(epoch_num, prog_bar, batch_size=200, 
                  steps_per_epoch=15):

    global global_epoch, loss_plot_inference, loss_plot_discrim
    
    # Single training epoch

    ## Ratio estimator training
        
    set_trainable(discriminator, True)

    for _ in tnrange(3*50, unit='step', desc='discriminator', 
                     leave=False):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

    metrics_dict_discrim = dict(zip(discriminator.metrics_names, 
                                    np.atleast_1d(metrics_discrim)))
    
    ## Inference model training
    
    set_trainable(ratio_estimator, False)

    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))

    for _ in tnrange(1, unit='step', desc='inference', leave=False):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)
        
    metrics_dict_inference = dict(zip(inference.metrics_names, 
                                      np.atleast_1d(metrics_inference)))

    global_epoch += 1
    
    # Plot Loss
 
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(), 
                                            metrics_dict_inference['loss']))

    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))

    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))

    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                  .format(metrics_dict_discrim['loss']))
    
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    
    ax1.relim()
    ax1.autoscale_view()
    
    # Contour Plot
    
    ax2.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')

    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    
    # Progress Bar Updates
    
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])

    return loss_plot_inference, loss_plot_discrim
In [65]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[65]:
In [66]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[66]:
In [67]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[67]:
In [68]:
with tqdm_notebook(total=50, 
                   unit='epoch', leave=True) as prog_bar:

    anim = FuncAnimation(fig, 
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    
    anim_html5_video = anim.to_html5_video()
    
HTML(anim_html5_video)
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.
Widget Javascript not detected.  It may not be installed or enabled properly.

Out[68]:

Evaluating the model

In [81]:
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
In [82]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [85]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax1.set_xlabel('$w_1$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

ax2.contourf(w1, w2, np.sum(llhs, axis=2), 
             cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')

ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)

plt.show()
In [101]:
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
In [105]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

ax1.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax1.scatter(*inference.predict(eps[::10]).T, 
            s=4.**2, alpha=.6, cmap='coolwarm_r')

ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')

ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)

sns.kdeplot(*inference.predict(eps).T,
            cmap='magma', ax=ax2)

plt.show()

Variational Inference with Implicit Models (forked from @fhuszar)

In [19]:
%matplotlib inline
from matplotlib import pylab as plt
import theano
from theano import tensor as T
import numpy as np
In [20]:
#create a regular grid in weight space for visualisation
wmin = -5
wmax = 5
wrange = np.linspace(wmin,wmax,300)
w = np.repeat(wrange[:,None],300,axis=1)
w = np.concatenate([[w.flatten()],[w.T.flatten()]])
In [21]:
w_grid = np.dstack(np.mgrid[wmin:wmax:300j,wmin:wmax:300j])
w_grid_flat = w_grid.reshape(300*300,2).T
In [22]:
w_grid.shape
Out[22]:
(300, 300, 2)
In [23]:
w[...,:5]
Out[23]:
array([[-5.        , -5.        , -5.        , -5.        , -5.        ],
       [-5.        , -4.96655518, -4.93311037, -4.89966555, -4.86622074]])
In [ ]:
 
In [24]:
np.all(w.T.reshape(300,300,2)==w_grid)
Out[24]:
True
In [25]:
np.all(w_grid_flat==w)
Out[25]:
True
In [26]:
prior_variance = 2
logprior = -(w**2).sum(axis=0)/2/prior_variance
plt.contourf(wrange, wrange, logprior.reshape(300,300), cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
Out[26]:
(-5, 5)
In [27]:
prior_variance = 2
logprior = -(w**2).sum(axis=0)/2/prior_variance
plt.contourf(wrange, wrange, logprior.reshape(300,300), cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
In [28]:
#generating a toy dataset with three manually selected observations
from scipy.stats import logistic
sigmoid = logistic.cdf
logsigmoid = logistic.logcdf

def likelihood(w,x,b=0,y=1):
    return logsigmoid(y*(np.dot(w.T,x) + b)).flatten()


x1 = np.array([[1.5],[1]])
x2 = np.array([[-1.5],[1]])
x3 = np.array([[0.5],[-1]])

y1=1
y2=1
y3=-1

llh1 = likelihood(w, x1, y=y1) 
llh2 = likelihood(w, x2, y=y2) 
llh3 = likelihood(w, x3, y=y3) 
In [29]:
#calculating unnormalised log posterior
#this is only for illustration
logpost = llh1 + llh2 + llh3 + logprior
In [30]:
#plotting the real log posterior
#the red dots show the three datapoints, the small line segments shows the direction
#in which the corresponding label shifts the posterior. Positive datapoints shift the 
# posterior away from zero in the direction of the datapoint, negative datapoints shift
# away from zero, in the opposite direction.

plt.contourf(wrange,
             wrange,
             np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])

plt.plot(x1[0],x1[1],'.r')
plt.plot([x1[0],x1[0]*(1+0.2*y1)],[x1[1],x1[1]*(1+0.2*y1)],'r-')

plt.plot(x2[0],x2[1],'.r')
plt.plot([x2[0],x2[0]*(1+0.2*y2)],[x2[1],x2[1]*(1+0.2*y2)],'r-')

plt.plot(x3[0],x3[1],'.r')
plt.plot([x3[0],x3[0]*(1+0.2*y3)],[x3[1],x3[1]*(1+0.2*y3)],'r-');

Fitting an approximate posterior

This part is for the actual GAN stuff. Here we define the generator and the discriminator networks in Lasagne, and code up the two loss functions in theano.

In [31]:
from lasagne.utils import floatX

from lasagne.layers import (
    InputLayer,
    DenseLayer,
    NonlinearityLayer)
from lasagne.nonlinearities import sigmoid

#defines a 'generator' network
def build_G(input_var=None, num_z = 3):
    
    network = InputLayer(input_var=input_var, shape=(None, num_z))
    
    network = DenseLayer(incoming = network, num_units=10)
    
    network = DenseLayer(incoming = network, num_units=20)
    
    network = DenseLayer(incoming = network, num_units=2, nonlinearity=None)
    
    return network

#defines the 'discriminator network'
def build_D(input_var=None):

    network = InputLayer(input_var=input_var, shape = (None, 2))
    
    network = DenseLayer(incoming = network, num_units=10)
    
    network = DenseLayer(incoming = network, num_units=20)
    
    network = DenseLayer(incoming = network, num_units=1, nonlinearity=None)
    
    normalised = NonlinearityLayer(incoming = network, nonlinearity = sigmoid)
    
    return { 'unnorm':network, 'norm':normalised }
/home/tiao/.virtualenvs/implicit/lib/python3.5/site-packages/theano/tensor/signal/downsample.py:6: UserWarning: downsample module has been moved to the theano.tensor.signal.pool module.
  "downsample module has been moved to the theano.tensor.signal.pool module.")
In [32]:
from lasagne.layers import get_output, get_all_params
from theano.printing import debugprint
from lasagne.updates import adam
from theano.tensor.shared_randomstreams import RandomStreams

#variables for input (design matrix), output labels, GAN noise variable, weights
x_var = T.matrix('design matrix')
y_var = T.vector('labels')
z_var = T.matrix('GAN noise')
w_var = T.matrix('weights')

#theano variables for things like batchsize, learning rate, etc.
batchsize_var = T.scalar('batchsize', dtype='int32')
prior_variance_var = T.scalar('prior variance')
learningrate_var = T.scalar('learning rate')

#random numbers for sampling from the prior or from the GAN
srng = RandomStreams(seed=1337)
z_rnd = srng.normal((batchsize_var,3))
prior_rnd = srng.normal((batchsize_var,2))

#instantiating the G and D networks
generator = build_G(z_var)
discriminator = build_D()

#these expressions are random samples from the generator and the prior, respectively
samples_from_grenerator = get_output(generator, z_rnd)
samples_from_prior = prior_rnd*T.sqrt(prior_variance_var)

#discriminator output for synthetic samples, both normalised and unnormalised (after/before sigmoid)
D_of_G = get_output(discriminator['norm'], inputs=samples_from_grenerator)
s_of_G = get_output(discriminator['unnorm'], inputs=samples_from_grenerator)

#discriminator output for real samples from the prior
D_of_prior = get_output(discriminator['norm'], inputs=samples_from_prior)

#loss of discriminator - simple binary cross-entropy loss
loss_D = -T.log(D_of_G).mean() - T.log(1-D_of_prior).mean()

#log likelihood for each synthetic w sampled from the generator
log_likelihood = T.log(
    T.nnet.sigmoid(
        (y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * samples_from_grenerator.dimshuffle('x', 1, 0))).sum(1)
    )
).sum(0).mean()

#loss for G is the sum of unnormalised discriminator output and the negative log likelihood
loss_G = s_of_G.mean() - log_likelihood

#compiling theano functions:
evaluate_generator = theano.function(
    [z_var],
    get_output(generator),
    allow_input_downcast=True
)

sample_generator = theano.function(
    [batchsize_var],
    samples_from_grenerator,
    allow_input_downcast=True,
)

sample_prior = theano.function(
    [prior_variance_var, batchsize_var],
    samples_from_prior,
    allow_input_downcast=True
)

params_D = get_all_params(discriminator['norm'], trainable=True)

updates_D = adam(
    loss_D,
    params_D,
    learning_rate = learningrate_var
)

train_D = theano.function(
    [learningrate_var, batchsize_var, prior_variance_var],
    loss_D,
    updates = updates_D,
    allow_input_downcast = True
)

params_G = get_all_params(generator, trainable=True)

updates_G = adam(
    loss_G,
    params_G,
    learning_rate = learningrate_var
)

train_G = theano.function(
    [x_var, y_var, learningrate_var, batchsize_var],
    loss_G,
    updates = updates_G,
    allow_input_downcast = True
)

evaluate_discriminator = theano.function(
    [w_var],
    get_output([discriminator['unnorm'],discriminator['norm']],w_var),
    allow_input_downcast = True
)

#this is to evaluate the log-likelihood of an arbitrary set of w
llh_for_w = T.nnet.sigmoid((y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * w_var.dimshuffle('x', 1, 0))).sum(1))

evaluate_loglikelihood = theano.function(
        [x_var, y_var, w_var],
        llh_for_w,
        allow_input_downcast = True
    )
WARNING (theano.tensor.blas): We did not found a dynamic library into the library_dir of the library we use for blas. If you use ATLAS, make sure to compile it with dynamics library.
In [33]:
plt.figure(figsize=(16,8))


plt.subplot(1,2,1)
plt.contourf(wrange, wrange, evaluate_discriminator(w.T)[0].reshape(300,300).T)

W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'g.')

W = sample_prior(prior_variance, 100)
plt.plot(W[:,0],W[:,1],'r.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('estimated log density ratio $\Phi^{-1}(D)$')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);

plt.title('log likelihood');
In [9]:
#checking that theano and numpy give the same likelihoods
import seaborn as sns
sns.set_context('poster')

X_train = np.concatenate([x1,x2,x3],axis=1).T
y_train = np.array([y1,y2,y3])
llh_theano = evaluate_loglikelihood(X_train, y_train, w.T)

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.contourf(wrange, wrange ,np.log(llh_theano).sum(0).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('theano loglikelihood')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('numpy loglikelihood')

assert np.allclose(llh1+llh2+llh3,np.log(llh_theano).sum(0))
In [10]:
batchsize = 200
KARPATHY_CONSTANT = 3e-4
learning_rate = KARPATHY_CONSTANT*10

prior_variance = 2
# train discriminator for some time before starting iterative process
%timeit -n 300 train_D(learning_rate, batchsize, prior_variance)
# print initial values of training errors
print (train_D(0, 100, prior_variance))
print (train_G(X_train, y_train, 0, batchsize))
for i in range(200):
    %timeit -n 50 train_D(learning_rate, batchsize, prior_variance)
    print (train_D(0, 100, prior_variance))
    print (train_G(X_train, y_train, 0, 100))
    train_G(X_train, y_train, learning_rate, batchsize)
756 µs ± 271 µs per loop (mean ± std. dev. of 7 runs, 300 loops each)
0.33257354666094224
4.996102628755116
490 µs ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.2119455387013967
4.1202309895028835
554 µs ± 110 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.3581621228462506
4.707070363772413
525 µs ± 50.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.21017178628830308
4.276304144298293
578 µs ± 103 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.40397309168350565
4.372633122086429
1.03 ms ± 469 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.46383786347060996
4.2827275263065765
532 µs ± 81.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.4742954938839633
4.3077686017458054
531 µs ± 84.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.4121499447379206
4.397708586401626
593 µs ± 156 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.44319494364792195
3.980595081088458
1.08 ms ± 449 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.42927350789948465
3.7900577311949553
852 µs ± 403 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.46957103677027356
3.710435643364359
806 µs ± 297 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.3602928390303859
3.4830954212562526
1.08 ms ± 245 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.40927156952073895
3.4364983784929883
761 µs ± 94.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.50816115146877
3.599252549464353
531 µs ± 41.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6243425998036307
3.3147001872257293
885 µs ± 255 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5214014031521584
3.2957678425189454
509 µs ± 51.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5777377856471808
3.3931108838380686
527 µs ± 66.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5769061981295343
2.9636664857310944
533 µs ± 40.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5919845798230468
3.0980678544245177
585 µs ± 120 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5221136104884718
2.9536828849957484
475 µs ± 11.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5899299184472118
3.018050123719444
514 µs ± 70.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5763176714698327
2.7836889692763553
515 µs ± 67.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6867567169306564
2.7724327113426255
846 µs ± 221 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6660490357124809
2.6334797044549028
503 µs ± 14.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6661692478486261
2.798222159866464
577 µs ± 79.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6284425682323345
2.45941002566091
616 µs ± 135 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7538949350668582
2.5561553966818646
857 µs ± 326 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6784818173772513
2.5707759109406383
662 µs ± 184 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6106598836310618
2.4423542734158934
504 µs ± 53 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7322285427403872
2.486232874363383
551 µs ± 82 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.5639550073386318
2.6146086412137772
979 µs ± 542 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7264087412017558
2.4332230875462812
482 µs ± 16.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7283933854707818
2.3681987399338826
530 µs ± 73.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6657642175741061
2.3105561541710573
1.03 ms ± 382 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7455752044114734
2.3375515479136095
509 µs ± 15.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.695036638914291
2.241469415845449
612 µs ± 191 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8288253950300575
2.1713170188630624
540 µs ± 41.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7325451695034837
2.3926006731706173
528 µs ± 39.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6537477437558776
2.221071531153876
474 µs ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6594337612422132
2.0842268725820223
548 µs ± 75.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7937575616200123
2.287965474601517
639 µs ± 181 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7824615530138115
2.2032675608338157
498 µs ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7041354493336556
2.223044383372638
499 µs ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7809340654032331
2.2406197901191884
552 µs ± 104 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8560787560241433
2.242268627133387
544 µs ± 80.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8250840009584233
2.0697757387240405
485 µs ± 16.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8030169872590546
2.168738405634988
497 µs ± 20.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7411801737133148
2.073490427374885
481 µs ± 17.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.6817005099920493
2.030122806263039
500 µs ± 27.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7137601683021236
2.0777852018070595
485 µs ± 14.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8651488860181719
2.0226483968095486
552 µs ± 81.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.835908020892067
2.0074897300615877
477 µs ± 17.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.758321051598106
1.9466429535011147
495 µs ± 27.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7774988089743448
2.0646111589711773
529 µs ± 52.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9320844472032326
2.059555610186674
1.03 ms ± 545 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8995300497945045
2.096498881611496
509 µs ± 34.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8120412541290547
2.0689781803221026
542 µs ± 70.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8022209799129976
2.006215299988812
505 µs ± 41.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9861490734907152
2.08059183752742
680 µs ± 176 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0280314462199212
1.990792685262567
874 µs ± 198 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9207858693798733
2.0218614795917036
663 µs ± 256 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9208536630021291
2.058661662973811
723 µs ± 181 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7806616908852124
2.092379434684223
482 µs ± 14.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7260779512186379
2.1071780917151113
570 µs ± 157 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8798079764013904
1.8842698809942315
540 µs ± 29.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7932729045583661
2.0850770570680015
546 µs ± 55.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8061788863252047
2.040032451204915
470 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8073895158455865
1.975324924853765
512 µs ± 39.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8935399317474548
2.0107243214818857
500 µs ± 20.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9027509732177925
1.938966456452621
562 µs ± 57.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9321444061436406
2.000069069107975
473 µs ± 8.21 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8318318871623693
1.9559127147670219
626 µs ± 178 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8517743244973623
1.958692968227925
480 µs ± 9.61 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9480755619110519
2.0338699768603945
1.06 ms ± 508 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8165538496544358
2.0816034247398614
488 µs ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0169935523291809
2.002750951769813
601 µs ± 83.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8609770779991297
1.8982912711475097
607 µs ± 62.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8046362052484096
1.9086013477667225
553 µs ± 43.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9325364646910803
1.9109967613404208
511 µs ± 53.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0304139762238786
1.9654525015937858
505 µs ± 57.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0540401463614057
1.9582530556159827
562 µs ± 77.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.914125377935025
1.8530085305028012
472 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9448266113197814
1.9636818919264836
528 µs ± 70.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.963835996295888
1.894341290727306
797 µs ± 237 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8615783954534084
1.9528007843738875
472 µs ± 20 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7910335704428098
1.9388852726528625
501 µs ± 58.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8739924169364888
1.9559208529674579
513 µs ± 113 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9051278386689579
1.8909948941135415
841 µs ± 81.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8813494975055529
1.9817322193447398
1.1 ms ± 363 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9470555429508094
1.9069496146627483
904 µs ± 167 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7877545237634651
1.9702616998910425
972 µs ± 128 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9205211070750159
2.014370612501838
1.13 ms ± 384 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.933990121985794
1.9095942141322215
748 µs ± 195 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9475402208390824
1.9494076334790447
576 µs ± 54.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8589765045801104
1.889592058420883
1.4 ms ± 659 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8563859315740332
1.8689975224271627
740 µs ± 77.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.928830260548582
1.8557579356544007
769 µs ± 68.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8437500234877205
1.9174548373736333
1.64 ms ± 439 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8850338692176746
2.0134683306049777
716 µs ± 51 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9338067281411887
1.9767214145036887
919 µs ± 537 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8205574093839865
1.7944354402306084
1.67 ms ± 530 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9737491365943363
1.951982798855767
1.53 ms ± 258 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9470174094386663
1.8578685241325357
1.4 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0005854999605033
1.9212340122496618
1.51 ms ± 242 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9273271404185992
1.8973621424806635
1.61 ms ± 604 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0220764708611707
1.9472312379222265
2.08 ms ± 548 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9248811191718072
1.8928291855604598
1.06 ms ± 286 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8290740848746385
1.8897263331754375
694 µs ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9799797827702561
1.952145411391983
791 µs ± 84.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9405966189131413
1.9026507286851595
1.19 ms ± 451 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9772764762220023
1.9203726525414786
754 µs ± 149 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0202258803759694
1.849928464204487
741 µs ± 49.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0584196314490577
1.8841021134541438
748 µs ± 28.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9354598685244473
1.9313836408226328
875 µs ± 274 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0344390971829374
1.8982657302884327
970 µs ± 328 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9304285291095017
1.865926431269422
708 µs ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8619103690449218
1.9286916051702754
759 µs ± 34 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0140179780608727
1.8924410298302774
994 µs ± 203 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8577751186154285
1.88701888782196
705 µs ± 37.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9723147845176738
1.8821015113591535
726 µs ± 52.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9939406608492074
1.8880703013393
712 µs ± 37.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0408800713205586
1.9248629503590609
704 µs ± 37.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9584778676205794
1.8842226406928675
718 µs ± 61 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8779676454323607
1.9256547680245397
687 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0903055260322778
1.8867120123836676
737 µs ± 57.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.950257682528716
1.879053173611434
790 µs ± 89.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0122997217021872
1.8890781635072065
747 µs ± 89.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9679689464883663
1.872122699531169
689 µs ± 12.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9096320417130161
1.923981415466368
748 µs ± 170 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0781932512358905
1.9057818053352777
779 µs ± 137 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0215196987892357
1.9189415812174135
777 µs ± 101 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9232977516007778
1.8273955416564407
785 µs ± 62.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9041884309796195
1.9066577343610374
1.04 ms ± 339 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9252223734993067
1.8814625440993034
917 µs ± 290 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8495598868404447
1.8300689253833733
743 µs ± 124 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0765161414654125
1.8574254538123613
921 µs ± 318 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8161981092285884
1.8785336282001959
852 µs ± 171 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8397428807457661
1.9578194249892404
862 µs ± 213 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0174859132107457
1.8849735639470198
764 µs ± 92.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8959106723444588
1.9333336705675197
783 µs ± 105 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.7916593306788984
1.9545142441497103
746 µs ± 49.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8812807784690273
1.8874628882436206
697 µs ± 11.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8673166863772079
1.9080597422636205
693 µs ± 10.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9367052052889462
1.885876287477752
862 µs ± 159 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9497331474099853
1.8672346413971976
716 µs ± 74.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9266406126684528
1.9517809122012415
727 µs ± 78.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9562966072014943
1.8804970072650735
740 µs ± 91.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9132301479427516
1.869284506246696
808 µs ± 118 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.923745649712572
1.8763534495471965
779 µs ± 67 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9287994767930432
1.8716414936200356
848 µs ± 136 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.977270467078829
1.947800409801077
806 µs ± 142 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9244729602885968
1.9226978447893457
538 µs ± 63.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9311189691484028
1.8837286976574859
530 µs ± 72.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8594519610091946
1.9563857351539964
583 µs ± 81.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9406141294930126
1.9025527526374169
501 µs ± 82 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.060407851595865
1.8790415774849558
584 µs ± 186 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9263146702991807
1.9159754127656088
464 µs ± 6.96 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.902627280762216
1.8929628711058266
559 µs ± 49.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.940825833072022
1.9478878029088262
521 µs ± 76.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9345100510488324
1.9208407501589857
535 µs ± 58.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0616573458438032
1.893844467893682
560 µs ± 73 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9259801183968952
1.9219868429525562
593 µs ± 122 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9248751926178794
1.8443098633578021
519 µs ± 96.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9872254701426095
1.9010606743345368
592 µs ± 146 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.966836792989094
1.9708824559631077
550 µs ± 74.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.142821965169572
1.9081736991446854
516 µs ± 58.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0132293465630695
1.9089849085528217
560 µs ± 88.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0953773163963214
1.946396226757867
502 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9454524901139612
1.8971714273008042
517 µs ± 51.3 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8541163678880714
1.909320854049065
810 µs ± 488 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8895323477510916
1.9089993747741696
504 µs ± 26.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9170592551570774
1.86418831198526
551 µs ± 65.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8480199192902722
1.9074352452706718
502 µs ± 27.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9042433310454134
1.9238371265096013
528 µs ± 28.5 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9203949047458799
1.8451912478345283
694 µs ± 142 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9110383512457658
1.8905749424872877
709 µs ± 123 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0010712461291937
1.8730913324179843
588 µs ± 44.8 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9278916565855064
1.8549047120428006
572 µs ± 57.1 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.1454243893281695
1.954262217001213
575 µs ± 77.6 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.018094793572398
1.8258876271142297
751 µs ± 123 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9738950121441763
1.8454200185068992
912 µs ± 214 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9080064857922512
1.899642305398118
734 µs ± 211 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0583399438882457
1.8884429617789698
545 µs ± 43.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9450661295876013
1.8795711746834745
1.37 ms ± 480 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9573964358834333
1.9593919496591903
755 µs ± 121 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9454227109527611
1.9336313835148111
864 µs ± 257 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8791163470909731
1.9229827155695596
590 µs ± 86.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.848623945255243
1.8983289871995557
626 µs ± 128 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0724451496139542
1.880654529439734
531 µs ± 38.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.949907274105624
1.8313554720358822
884 µs ± 407 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9016927240886459
1.881743624186996
591 µs ± 73.2 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0428522975502288
1.860607902740477
769 µs ± 178 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8979421583723741
1.8595749389028327
556 µs ± 92.9 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8881372028405056
1.8886899790362435
541 µs ± 71.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9787504517363712
1.9174609932807847
548 µs ± 63.4 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9515102024073514
1.8567866536422881
501 µs ± 21.7 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9102032994396412
1.9230050099499507
672 µs ± 191 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
1.0133132850048843
1.900949107371182
699 µs ± 376 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9247467938890921
1.9279616901206698
1.74 ms ± 381 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.9512289828586543
1.9188489050607411
The slowest run took 4.12 times longer than the fastest. This could mean that an intermediate result is being cached.
856 µs ± 500 µs per loop (mean ± std. dev. of 7 runs, 50 loops each)
0.8319935173995281
1.8610914498415747
In [16]:
plt.figure(figsize=(16,8))


plt.subplot(1,2,1)
plt.contourf(wrange, wrange, evaluate_discriminator(w.T)[0].reshape(300,300).T)

W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'g.')

W = sample_prior(prior_variance, 100)
plt.plot(W[:,0],W[:,1],'r.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('estimated log density ratio $\Phi^{-1}(D)$')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);

plt.title('log likelihood');
In [12]:
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(1000)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true log posterior');
In [13]:
sns.set_style('whitegrid')
plt.subplot(1,2,2)

W = sample_generator(5000)
plot = sns.kdeplot(W[:,0],W[:,1])
plt.axis('square')
plot.set(xlim=(wmin,wmax))
plot.set(ylim=(wmin,wmax))
plt.title('kde of approximate posterior')

plt.subplot(1,2,1)
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true posterior');
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 7)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
from tqdm import tnrange
Using TensorFlow backend.
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([  .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4729044328 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4727278616 dense_1: Denseinput:output:(None, 2)(None, 10)4729044328->4727278616 4729043432 dense_2: Denseinput:output:(None, 10)(None, 20)4727278616->4729043432 4729057176 logit: Denseinput:output:(None, 20)(None, 1)4729043432->4729057176 4721891144 activation_1: Activationinput:output:(None, 1)(None, 1)4729057176->4721891144
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.53143018484115601, 1.0]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4732353784 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4732435816 dense_3: Denseinput:output:(None, 3)(None, 10)4732353784->4732435816 4732434640 dense_4: Denseinput:output:(None, 10)(None, 20)4732435816->4732434640 4716115280 dense_5: Denseinput:output:(None, 20)(None, 2)4732434640->4716115280
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(200, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(200, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
Out[37]:
(-5, 5)
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=200, steps_per_epoch=15):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=60, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [44]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [45]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[45]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [46]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[46]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [47]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [48]:
sess = K.get_session()
In [49]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[49]:
True
In [50]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [51]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [52]:
elbo = make_elbo(ratio_estimator)
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [54]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [55]:
inference.compile(loss=inference_loss, 
                  optimizer=Adam(lr=LEARNING_RATE))
In [56]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [57]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[57]:
<tf.Tensor 'concat:0' shape=(200, 3) dtype=float32>
In [58]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[58]:
-3.8279114
In [59]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/200 [===>..........................] - ETA: 0s
Out[59]:
3.8279113578796387

Training

In [60]:
for epoch in tnrange(200, desc='epoch'):

    set_trainable(ratio_estimator, False)

    for _ in tnrange(1, desc='generator'):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, np.tile(y, reps=(BATCH_SIZE, 1)))

    set_trainable(discriminator, True)

    for _ in tnrange(3*50, desc='discriminator'):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)

In [61]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [62]:
metrics = discriminator.evaluate(inputs, targets)
 32/400 [=>............................] - ETA: 0s
In [63]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [64]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.gray)

ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

My PhD in Weeks

This is your PhD, and it's ending one week at a time.
  1 2 3 4
1        
       
       
       
5        
       
       
       
       
10        
       
       
       
       
15        
       
       
       
       
20        
       
       
       
       
25        
       
      x
      x
      x
30       x
      x
      x
      x
      x
35       x
      x
      x
      x
      x
40       x
      x
      x
      x
      x
45       x
      x
      x
      x
      x
50       x
      x
      x

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 6)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4766144608 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4766145224 dense_1: Denseinput:output:(None, 2)(None, 10)4766144608->4766145224 4767336208 dense_2: Denseinput:output:(None, 10)(None, 20)4766145224->4767336208 4766143992 logit: Denseinput:output:(None, 20)(None, 1)4767336208->4766143992 4766535464 activation_1: Activationinput:output:(None, 1)(None, 1)4766143992->4766535464
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.72363483905792236, 0.40000000596046448]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4769765752 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4769787520 dense_3: Denseinput:output:(None, 3)(None, 10)4769765752->4769787520 4769190184 dense_4: Denseinput:output:(None, 10)(None, 20)4769787520->4769190184 4769949288 dense_5: Denseinput:output:(None, 20)(None, 2)4769190184->4769949288
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Evidence lower bound

In [44]:
def set_trainable(model, trainable):
    """inorder traversal"""
    model.trainable = trainable

    if isinstance(model, Model): # i.e. has layers
        for layer in model.layers:
            set_trainable(layer, trainable)
In [45]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[45]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [46]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[46]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [47]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [48]:
sess = K.get_session()
In [49]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[49]:
True
In [50]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [51]:
def make_elbo(ratio_estimator):
    
    set_trainable(ratio_estimator, False)
    
    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [52]:
elbo = make_elbo(ratio_estimator)
In [53]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [59]:
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
In [60]:
inference.compile(loss=inference_loss, 
                  optimizer='adam')
In [61]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [62]:
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[62]:
<tf.Tensor 'concat_1:0' shape=(128, 3) dtype=float32>
In [63]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[63]:
-3.9920437
In [64]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/128 [======>.......................] - ETA: 0s
Out[64]:
3.9920437335968018

Training

In [70]:
for epoch in range(3*200):

    set_trainable(ratio_estimator, False)

    for _ in range(1):

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, np.tile(y, reps=(BATCH_SIZE, 1)))

    set_trainable(discriminator, True)

    for _ in range(3*50):

        w_sample_prior = prior.rvs(size=BATCH_SIZE)

        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))

        metrics_discrim = discriminator.train_on_batch(inputs, targets)
In [71]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [72]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [73]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [74]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 5)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 4796441544 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)4797073168 dense_1: Denseinput:output:(None, 2)(None, 10)4796441544->4797073168 4796631920 dense_2: Denseinput:output:(None, 10)(None, 20)4797073168->4796631920 4796443560 logit: Denseinput:output:(None, 20)(None, 1)4796631920->4796443560 4795146192 activation_1: Activationinput:output:(None, 1)(None, 1)4796443560->4795146192
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.63304531574249268, 0.60000002384185791]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 4783101656 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)4797501680 dense_3: Denseinput:output:(None, 3)(None, 10)4783101656->4797501680 4797502688 dense_4: Denseinput:output:(None, 10)(None, 20)4797501680->4797502688 4799943456 dense_5: Denseinput:output:(None, 20)(None, 2)4797502688->4799943456
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Inference Model Training

In [44]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[44]:
<tf.Tensor 'Sigmoid:0' shape=(300, 300, 3) dtype=float32>
In [45]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[45]:
<tf.Tensor 'mul_33:0' shape=(300, 300, 3) dtype=float32>
In [46]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [47]:
sess = K.get_session()
In [48]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[48]:
True
In [49]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [50]:
def make_elbo(ratio_estimator):

    def elbo(y_true, w_sample):
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true, 
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)

    return elbo
In [51]:
elbo = make_elbo(ratio_estimator)
In [52]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [94]:
inference.compile(loss=make_elbo(ratio_estimator), 
                  optimizer='adam')
In [95]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [96]:
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
 32/128 [======>.......................] - ETA: 0s
Out[96]:
-3.2424654364585876
In [105]:
# equiv. to use of tile above
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0), 
                           axis=0, rep=BATCH_SIZE)
y_true
Out[105]:
<tf.Tensor 'concat_10:0' shape=(128, 3) dtype=float32>
In [107]:
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
Out[107]:
-3.2424655

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 4)

In [177]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [178]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
In [179]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [180]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [181]:
K.tf.__version__
Out[181]:
'1.2.1'
In [182]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [183]:
w_min, w_max = -5, 5
In [184]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [185]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[185]:
(300, 300, 2)
In [186]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [187]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[187]:
(300, 300)
In [188]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [189]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [190]:
X = np.vstack((x1, x2, x3))
X.shape
Out[190]:
(3, 2)
In [191]:
y1 = 1
y2 = 1
y3 = 0
In [192]:
y = np.stack((y1, y2, y3))
y.shape
Out[192]:
(3,)
In [193]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [194]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[194]:
(300, 300, 3)
In [195]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [196]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [197]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 140297740092864 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)140297740092640 dense_1: Denseinput:output:(None, 2)(None, 10)140297740092864->140297740092640 140297740091520 dense_2: Denseinput:output:(None, 10)(None, 20)140297740092640->140297740091520 140297740093312 logit: Denseinput:output:(None, 20)(None, 1)140297740091520->140297740093312 140297740622592 activation_1: Activationinput:output:(None, 1)(None, 1)140297740093312->140297740622592
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.63745558261871338, 0.80000001192092896]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 140297740961440 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)140297740025472 dense_3: Denseinput:output:(None, 3)(None, 10)140297740961440->140297740025472 140297740023568 dense_4: Denseinput:output:(None, 10)(None, 20)140297740025472->140297740023568 140297740591512 dense_5: Denseinput:output:(None, 20)(None, 2)140297740023568->140297740591512
In [31]:
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
Out[31]:
(128, 2)
In [32]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
Out[32]:
(128, 2)
In [33]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [36]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [37]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
Discriminator pre-training
In [38]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

    ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [39]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[39]:
In [40]:
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
In [41]:
metrics = discriminator.evaluate(inputs, targets)
 32/256 [==>...........................] - ETA: 0s
In [42]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
In [43]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap='magma')

ax.scatter(*inputs.T, c=targets, alpha=.8, cmap='coolwarm')

train_info = dict(zip(discriminator.metrics_names, metrics))

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

ax.text(0.05, 0.05, 
        ('accuracy: {binary_accuracy:.2f}\n'        
         'loss: {loss:.2f}').format(**train_info), 
        transform=ax.transAxes, bbox=props)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Inference Model Training

In [214]:
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
Out[214]:
<tf.Tensor 'Sigmoid_13:0' shape=(300, 300, 3) dtype=float32>
In [215]:
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
Out[215]:
<tf.Tensor 'mul_50:0' shape=(300, 300, 3) dtype=float32>
In [207]:
llhs_keras = - K.binary_crossentropy(
                   y_pred, 
                   y_true, 
                   from_logits=False)
In [208]:
sess = K.get_session()
In [209]:
np.sum(llhs, axis=-1)
Out[209]:
array([[-20.08, -20.01, -19.94, ...,  -2.72,  -2.69,  -2.66],
       [-20.02, -19.95, -19.88, ...,  -2.68,  -2.64,  -2.61],
       [-19.95, -19.88, -19.81, ...,  -2.63,  -2.6 ,  -2.56],
       ..., 
       [-15.1 , -15.03, -14.96, ...,  -2.55,  -2.52,  -2.49],
       [-15.13, -15.06, -14.99, ...,  -2.59,  -2.56,  -2.53],
       [-15.16, -15.09, -15.02, ...,  -2.64,  -2.61,  -2.58]])
In [210]:
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
Out[210]:
True
In [211]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)), 
            cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [216]:
def make_elbo(ratio_estimator):

    def elbo(y_true, y_pred):
        log_likelihood = - K.binary_crossentropy(y_pred, y_true)
        kl_estimate = K.mean(ratio_estimator(y_pred), axis=-1)
        return log_likelihood - kl_estimate
    
    return elbo
In [217]:
elbo = make_elbo(ratio_estimator)
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 3)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
# display animation inline
plt.rc('animation', html='html5')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [7]:
w_min, w_max = -5, 5
In [8]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [9]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[9]:
(300, 300, 2)
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[11]:
(300, 300)
In [12]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [13]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [14]:
X = np.vstack((x1, x2, x3))
X.shape
Out[14]:
(3, 2)
In [15]:
y1 = 1
y2 = 1
y3 = 0
In [16]:
y = np.stack((y1, y2, y3))
y.shape
Out[16]:
(3,)
In [17]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [18]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[18]:
(300, 300, 3)
In [19]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [20]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [21]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap='magma')

ax.scatter(*X.T, c=y, cmap='coolwarm')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [22]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [23]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [24]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[24]:
G 140638038322816 dense_1_input: InputLayerinput:output:(None, 2)(None, 2)140638038320800 dense_1: Denseinput:output:(None, 2)(None, 10)140638038322816->140638038320800 140638038322592 dense_2: Denseinput:output:(None, 10)(None, 20)140638038320800->140638038322592 140638038321808 logit: Denseinput:output:(None, 20)(None, 1)140638038322592->140638038321808 140638038360976 activation_1: Activationinput:output:(None, 1)(None, 1)140638038321808->140638038360976
In [25]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [26]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [27]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[27]:
[0.48928388953208923, 0.80000001192092896]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [28]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_4 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_5 (Dense)              (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [29]:
phi = inference.trainable_weights
phi
Out[29]:
[<tf.Variable 'dense_3/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_3/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_4/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_4/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_5/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_5/bias:0' shape=(2,) dtype=float32_ref>]
In [30]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[30]:
G 140638039992360 dense_3_input: InputLayerinput:output:(None, 3)(None, 3)140638037305216 dense_3: Denseinput:output:(None, 3)(None, 10)140638039992360->140638037305216 140638037305048 dense_4: Denseinput:output:(None, 10)(None, 20)140638037305216->140638037305048 140638039956000 dense_5: Denseinput:output:(None, 20)(None, 2)140638037305048->140638039956000
In [31]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [32]:
w_posterior_samples = inference.predict(eps)
w_posterior_samples.shape
Out[32]:
(128, 2)
In [33]:
w_prior_samples = prior.rvs(size=BATCH_SIZE)
w_prior_samples.shape
Out[33]:
(128, 2)
In [34]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*w_posterior_samples.T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [35]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)

scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

t = ax.text(0.05, 0.85, 'step: 0', 
            transform=ax.transAxes, bbox=props)

plt.show()
Discriminator pre-training
In [36]:
def train_animate(epoch_num, batch_size=128, steps_per_epoch=20):

    for step in range(steps_per_epoch):

        w_sample_prior = prior.rvs(size=batch_size)

        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)

        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))

        metrics = discriminator.train_on_batch(inputs, targets)

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

    ax.scatter(*w_sample_posterior.T, alpha=.8)
    ax.scatter(*w_sample_prior.T, alpha=.8)

    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    ax.text(0.05, 0.05, 
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'        
             'loss: {loss:.2f}').format(**train_info), 
            transform=ax.transAxes, bbox=props)

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    return ax
In [37]:
FuncAnimation(fig, train_animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[37]:
In [ ]:
 

Variational Inference with Implicit Approximate Inference Models (WIP Pt. 2)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
/Users/tiao/.virtualenvs/implicit/lib/python3.6/site-packages/IPython/html.py:14: ShimWarning: The `IPython.html` package has been deprecated since IPython 4.0. You should import from `notebook` instead. `IPython.html.widgets` has moved to `ipywidgets`.
  "`IPython.html.widgets` has moved to `ipywidgets`.", ShimWarning)
In [3]:
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
# display animation inline
plt.rc('animation', html='html5')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

In [86]:
w_min, w_max = -5, 5
In [87]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [88]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[88]:
(300, 300, 2)
In [89]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [90]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[90]:
(300, 300)
In [91]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [92]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [93]:
X = np.vstack((x1, x2, x3))
X.shape
Out[93]:
(3, 2)
In [94]:
y1 = 1
y2 = 1
y3 = 0
In [95]:
y = np.stack((y1, y2, y3))
y.shape
Out[95]:
(3,)
In [96]:
def log_likelihood(w, x, y):
    # equiv. to negative binary cross entropy
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
In [97]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[97]:
(300, 300, 3)
In [98]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [104]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [105]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.plot(*np.vstack((x1,x2,x3)).T, 'ro')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

$T_{\psi}(x, z)$

Here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

In [106]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [107]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
In [108]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[108]:
G 4390601896 dense_6_input: InputLayerinput:output:(None, 2)(None, 2)4656746112 dense_6: Denseinput:output:(None, 2)(None, 10)4390601896->4656746112 4656743088 dense_7: Denseinput:output:(None, 10)(None, 20)4656746112->4656743088 4705728888 logit: Denseinput:output:(None, 20)(None, 1)4656743088->4705728888 4705708240 activation_2: Activationinput:output:(None, 1)(None, 1)4705728888->4705708240
In [109]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

Initial density ratio, prior to any training

In [110]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [111]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[111]:
[0.45842784643173218, 1.0]

Approximate Inference Model

$z_{\phi}(x, \epsilon)$

Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

In [112]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_8 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_9 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_10 (Dense)             (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model

In [113]:
phi = inference.trainable_weights
phi
Out[113]:
[<tf.Variable 'dense_8/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_8/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_9/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_9/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_10/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_10/bias:0' shape=(2,) dtype=float32_ref>]
In [114]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[114]:
G 4638359224 dense_8_input: InputLayerinput:output:(None, 3)(None, 3)4706841152 dense_8: Denseinput:output:(None, 3)(None, 10)4638359224->4706841152 4638452536 dense_9: Denseinput:output:(None, 10)(None, 20)4706841152->4638452536 4638609928 dense_10: Denseinput:output:(None, 20)(None, 2)4638452536->4638609928
In [115]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [116]:
w_posterior_samples = inference.predict(eps)
w_posterior_samples.shape
Out[116]:
(128, 2)
In [117]:
w_prior_samples = prior.rvs(size=BATCH_SIZE)
w_prior_samples.shape
Out[117]:
(128, 2)
In [120]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*w_posterior_samples.T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [122]:
fig, ax = plt.subplots(figsize=(5, 5))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)

scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

t = ax.text(0.05, 0.85, 'step: 0', 
            transform=ax.transAxes, bbox=props)

plt.show()
Discriminator pre-training
In [164]:
def prior_samples_gen(batch_size):

    while True:
        yield prior.rvs(size=batch_size)
In [165]:
def posterior_samples_gen(inference_model, batch_size):

    while True:
        eps = np.random.randn(batch_size, NOISE_DIM)
        yield inference_model.predict(eps)
In [172]:
def discriminator_data_gen(inference_model, batch_size):
    
    for samples_prior, samples_posterior in zip(prior_samples_gen(batch_size), 
                                                posterior_samples_gen(inference_model, batch_size)):
        inputs = np.vstack((samples_prior, samples_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
        yield inputs, targets
In [184]:
h = discriminator.fit_generator(generator=discriminator_data_gen(inference, 128), steps_per_epoch=32, epochs=2)
Epoch 1/2
32/32 [==============================] - 0s - loss: 0.1257 - binary_accuracy: 0.9559     
Epoch 2/2
32/32 [==============================] - 0s - loss: 0.1169 - binary_accuracy: 0.9598     
In [185]:
h.history['loss'][-1]
Out[185]:
0.11688009253703058
In [167]:
metrics = discriminator.train_on_batch(D_input, D_labels)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-167-0dfc15023395> in <module>()
----> 1 metrics = discriminator.train_on_batch(D_input, D_labels)

NameError: name 'D_input' is not defined
In [36]:
def animate(step):

    ax.cla()

    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)

    ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

    info_dict = dict(zip(discriminator.metrics_names, metrics))
    info_dict['step'] = step

    props = dict(boxstyle='round', facecolor='w', alpha=0.5)

    t = ax.text(0.05, 0.85, 'step: 0', 
                transform=ax.transAxes, bbox=props)

    scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
    scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)
    
    return ax
In [37]:
FuncAnimation(fig, animate, frames=50, 
              interval=200, # 5 fps
              blit=False)
Out[37]:
In [38]:
fig, ax = plt.subplots(figsize=(7, 7))

w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

cset = ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma, animate=True)

scatter_posterior = ax.scatter(*w_posterior_samples.T, alpha=.8)
scatter_prior = ax.scatter(*w_prior_samples.T, alpha=.8)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

props = dict(boxstyle='round', facecolor='w', alpha=0.5)

t = ax.text(0.05, 0.85, 'step: 0', 
            transform=ax.transAxes, bbox=props)

plt.show()
In [39]:
cset.collections
Out[39]:
<a list of 8 mcoll.PathCollection objects>
In [40]:
from matplotlib.collections import PatchCollection
In [ ]:
 
In [41]:
dir(cset.collections[0])
Out[41]:
['_A',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_agg_filter',
 '_alpha',
 '_animated',
 '_antialiaseds',
 '_axes',
 '_bcast_lwls',
 '_clipon',
 '_clippath',
 '_contains',
 '_edge_default',
 '_edgecolors',
 '_facecolors',
 '_factor',
 '_get_bool',
 '_get_value',
 '_gid',
 '_hatch',
 '_hatch_color',
 '_is_filled',
 '_is_stroked',
 '_label',
 '_linestyles',
 '_linewidths',
 '_mouseover',
 '_offset_position',
 '_offsets',
 '_oid',
 '_original_edgecolor',
 '_original_facecolor',
 '_path_effects',
 '_paths',
 '_picker',
 '_pickradius',
 '_prepare_points',
 '_prop_order',
 '_propobservers',
 '_rasterized',
 '_remove_method',
 '_set_edgecolor',
 '_set_facecolor',
 '_set_gc_clip',
 '_sizes',
 '_sketch',
 '_snap',
 '_stale',
 '_sticky_edges',
 '_transOffset',
 '_transform',
 '_transformSet',
 '_transforms',
 '_uniform_offsets',
 '_url',
 '_urls',
 '_us_linestyles',
 '_us_lw',
 '_visible',
 'add_callback',
 'add_checker',
 'aname',
 'autoscale',
 'autoscale_None',
 'axes',
 'callbacksSM',
 'changed',
 'check_update',
 'clipbox',
 'cmap',
 'colorbar',
 'contains',
 'convert_xunits',
 'convert_yunits',
 'draw',
 'eventson',
 'figure',
 'findobj',
 'format_cursor_data',
 'get_agg_filter',
 'get_alpha',
 'get_animated',
 'get_array',
 'get_axes',
 'get_children',
 'get_clim',
 'get_clip_box',
 'get_clip_on',
 'get_clip_path',
 'get_cmap',
 'get_contains',
 'get_cursor_data',
 'get_dashes',
 'get_datalim',
 'get_edgecolor',
 'get_edgecolors',
 'get_facecolor',
 'get_facecolors',
 'get_figure',
 'get_fill',
 'get_gid',
 'get_hatch',
 'get_label',
 'get_linestyle',
 'get_linestyles',
 'get_linewidth',
 'get_linewidths',
 'get_offset_position',
 'get_offset_transform',
 'get_offsets',
 'get_path_effects',
 'get_paths',
 'get_picker',
 'get_pickradius',
 'get_rasterized',
 'get_sizes',
 'get_sketch_params',
 'get_snap',
 'get_transform',
 'get_transformed_clip_path_and_affine',
 'get_transforms',
 'get_url',
 'get_urls',
 'get_visible',
 'get_window_extent',
 'get_zorder',
 'have_units',
 'hitlist',
 'is_figure_set',
 'is_transform_set',
 'mouseover',
 'norm',
 'pchanged',
 'pick',
 'pickable',
 'properties',
 'remove',
 'remove_callback',
 'set',
 'set_agg_filter',
 'set_alpha',
 'set_animated',
 'set_antialiased',
 'set_antialiaseds',
 'set_array',
 'set_axes',
 'set_clim',
 'set_clip_box',
 'set_clip_on',
 'set_clip_path',
 'set_cmap',
 'set_color',
 'set_contains',
 'set_dashes',
 'set_edgecolor',
 'set_edgecolors',
 'set_facecolor',
 'set_facecolors',
 'set_figure',
 'set_gid',
 'set_hatch',
 'set_label',
 'set_linestyle',
 'set_linestyles',
 'set_linewidth',
 'set_linewidths',
 'set_lw',
 'set_norm',
 'set_offset_position',
 'set_offsets',
 'set_path_effects',
 'set_paths',
 'set_picker',
 'set_pickradius',
 'set_rasterized',
 'set_sizes',
 'set_sketch_params',
 'set_snap',
 'set_transform',
 'set_url',
 'set_urls',
 'set_visible',
 'set_zorder',
 'stale',
 'stale_callback',
 'sticky_edges',
 'to_rgba',
 'update',
 'update_dict',
 'update_from',
 'update_scalarmappable',
 'zorder']
In [42]:
import matplotlib.patches as patches
In [43]:
from matplotlib.collections import PathCollection
In [44]:
fig, ax = plt.subplots(figsize=(7, 7))

ax.add_collection(cset.collections[4])

plt.show()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-44-d9b87f692bb4> in <module>()
      1 fig, ax = plt.subplots(figsize=(7, 7))
      2 
----> 3 ax.add_collection(cset.collections[4])
      4 
      5 plt.show()

~/.virtualenvs/implicit/lib/python3.6/site-packages/matplotlib/axes/_base.py in add_collection(self, collection, autolim)
   1754             collection.set_label('_collection%d' % len(self.collections))
   1755         self.collections.append(collection)
-> 1756         self._set_artist_props(collection)
   1757 
   1758         if collection.get_clip_path() is None:

~/.virtualenvs/implicit/lib/python3.6/site-packages/matplotlib/axes/_base.py in _set_artist_props(self, a)
    922     def _set_artist_props(self, a):
    923         """set the boilerplate props for artists added to axes"""
--> 924         a.set_figure(self.figure)
    925         if not a.is_transform_set():
    926             a.set_transform(self.transData)

~/.virtualenvs/implicit/lib/python3.6/site-packages/matplotlib/artist.py in set_figure(self, fig)
    648         # to more than one Axes
    649         if self.figure is not None:
--> 650             raise RuntimeError("Can not put single artist in "
    651                                "more than one figure")
    652         self.figure = fig

RuntimeError: Can not put single artist in more than one figure
In [134]:
fig, ax = plt.subplots(figsize=(7, 7))

cset = ax.contourf(np.linspace(-3, 3, 32), np.linspace(-3, 3, 32), np.random.randn(32, 32), cmap='magma')
scat = ax.scatter(*np.random.randn(2, 128), alpha=.8)

ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)

plt.show()
In [138]:
def animate(step):

    ax.cla()
    
    ax.contourf(np.linspace(-3, 3, 32), np.linspace(-3, 3, 32), np.random.randn(32, 32), cmap='magma')
    ax.scatter(*np.random.randn(2, 128), alpha=.8)

    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)
    
    return scat
In [139]:
FuncAnimation(fig, animate, frames=25, 
              interval=200) # 5 fps
Out[139]:
In [ ]: